package com.menghao.rpc.spring;

import com.menghao.rpc.RpcConstants;
import com.menghao.rpc.provider.annotation.Provider;
import com.menghao.rpc.provider.regisiter.ProviderRepository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * Registers in-process service providers with the {@link ProviderRepository}.
 *
 * <p>Before a bean is initialized, the interfaces of its class (and of its superclasses) are
 * scanned, including their parent interfaces, for the {@link Provider} annotation. After the
 * bean is initialized, the bean instance handed to this processor is registered once for every
 * annotated interface that was found.</p>
 *
 * @author MarvelCode
 * @see ProviderRepository
 */
public class ProviderPostProcessor implements BeanPostProcessor {

    private static final Logger LOGGER = LoggerFactory.getLogger(ProviderPostProcessor.class);

    /** Bean name to the {@code @Provider} interfaces found for that bean, pending registration. */
    private final Map<String, Set<Class<?>>> candidates = new HashMap<>(8);

    private final ProviderRepository providerRepository;

    /**
     * @param providerRepository repository that receives every discovered provider
     */
    public ProviderPostProcessor(ProviderRepository providerRepository) {
        this.providerRepository = providerRepository;
    }

    /**
     * Collects every {@link Provider}-annotated interface reachable from the bean's class and
     * remembers them under {@code beanName} until the bean has been initialized.
     *
     * @param bean     the bean instance about to be initialized
     * @param beanName the name of the bean
     * @return the same {@code bean} instance, unmodified
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        // Insertion-ordered so that registration follows discovery order.
        Set<Class<?>> found = new LinkedHashSet<>();
        // Class.getInterfaces() only reports interfaces declared directly on a class, so the
        // superclass chain is walked to also see interfaces implemented by parent classes.
        for (Class<?> type = bean.getClass(); type != null; type = type.getSuperclass()) {
            for (Class<?> sourceInterface : type.getInterfaces()) {
                collectProviderInterfaces(sourceInterface, found);
            }
        }
        if (!found.isEmpty()) {
            candidates.put(beanName, found);
        }
        return bean;
    }

    /**
     * Adds {@code sourceInterface} to {@code found} when it carries {@link Provider}, then
     * recurses into all of its parent interfaces.
     *
     * @param sourceInterface the interface to inspect
     * @param found           accumulator for annotated interfaces
     */
    private void collectProviderInterfaces(Class<?> sourceInterface, Set<Class<?>> found) {
        // add() returns false for an interface reached twice, so each one is logged only once.
        if (sourceInterface.isAnnotationPresent(Provider.class) && found.add(sourceInterface)) {
            LOGGER.info("{}find @Provider-{}", RpcConstants.LOG_RPC_PREFIX, sourceInterface.getName());
        }
        // An interface without parents yields an empty array, which ends the recursion.
        for (Class<?> superInterface : sourceInterface.getInterfaces()) {
            collectProviderInterfaces(superInterface, found);
        }
    }

    /**
     * Registers the initialized bean for every {@link Provider} interface recorded for
     * {@code beanName}, if any.
     *
     * @param bean     the bean instance after initialization
     * @param beanName the name of the bean
     * @return the same {@code bean} instance, unmodified
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        // remove() rather than get(): the entry is no longer needed once the bean is registered.
        Set<Class<?>> providerInterfaces = candidates.remove(beanName);
        if (providerInterfaces != null) {
            for (Class<?> sourceInterface : providerInterfaces) {
                providerRepository.register(sourceInterface, beanName, bean);
            }
        }
        return bean;
    }
}
